import torch
# Merge the per-module checkpoints (backbone / neck / decode_head) found in
# `my_path` into a single checkpoint of the form
#     {"state_dict": {"<part>.<param_key>": tensor, ...}}
# and write it to `<my_path>/seg.pt`.
#
# Optionally (FILTER_UNKNOWN = True), each merged key is checked against the
# "state_dict" of the reference pretrained checkpoint at `pretrained_path`.
# Keys absent from the reference are left out of seg.pt and printed instead.
pretrained_path = "/home/liyuke/lyk_work/segmatch/pretrained/repvit_m1_1_ade20k.pth"
my_path = "/home/liyuke/lyk_work/segmatch/ckpts_Segonly/segmentation/020"
part_names = ("backbone", "neck", "decode_head")
# Passed as `map_location` to every torch.load call below.
map_location = "cuda"
# Off by default: every key of every part is kept. This matches the previous
# behaviour, where the reference check was commented out.
FILTER_UNKNOWN = False

# The reference checkpoint is only needed for filtering; skip loading it otherwise.
reference_keys = None
if FILTER_UNKNOWN:
    model_param = torch.load(pretrained_path, map_location=map_location)
    reference_keys = set(model_param["state_dict"].keys())

new_param = {}
unknown_param = {}
for name in part_names:
    part_param = torch.load(f"{my_path}/{name}.pth", map_location=map_location)
    for key, param in part_param.items():
        new_name = f"{name}.{key}"
        if reference_keys is None or new_name in reference_keys:
            new_param[new_name] = param
        else:
            # Record under the prefixed name: bare keys can collide across
            # parts (e.g. both backbone and neck having "conv.weight").
            unknown_param[new_name] = param

out = {"state_dict": new_param}
torch.save(out, my_path + "/seg.pt")
# Empty dict unless FILTER_UNKNOWN is enabled and some keys were rejected.
print(unknown_param)


